{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAY4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数组的创建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NumPy 数组基础笔记\n",
    "\n",
    "### 1. 理解数组的维度 (Dimensions)\n",
    "\n",
    "NumPy 数组的**维度 (Dimension)** 或称为 **轴 (Axis)** 的概念，与我们日常理解的维度非常相似。\n",
    "\n",
    "*   **直观判断:** 数组的维度层数通常可以通过打印输出时**中括号 `[]` 的嵌套层数**来初步确定：\n",
    "    *   一层 `[]`: **一维 (1D)** 数组。\n",
    "    *   两层 `[]`: **二维 (2D)** 数组。\n",
    "    *   三层 `[]`: **三维 (3D)** 数组，依此类推。\n",
    "\n",
    "### 2. NumPy 数组与深度学习 Tensor 的关系\n",
    "\n",
    "在后续进行频繁的数学运算时，尤其是在深度学习领域，对 NumPy 数组的理解非常有帮助，因为 PyTorch 或 TensorFlow 中的 **Tensor** 张量本质上可以视为**支持 GPU 加速**和**自动微分**的 NumPy 数组。掌握 NumPy 的基本操作，能极大地降低学习 Tensor 的门槛。关于 NumPy 更深入的性质，我们留待后续探讨。\n",
    "\n",
    "### 3. 一维数组 (1D Array)\n",
    "\n",
    "一维数组在结构上与 Python 中的列表（List）非常相似。它们的主要区别在于：\n",
    "\n",
    "*   **打印输出格式:** 当使用 `print()` 函数输出时：\n",
    "    *   NumPy 一维数组的元素之间默认使用**空格**分隔。\n",
    "    *   Python 列表的元素之间使用**逗号**分隔。\n",
    "\n",
    "    *   **示例 (一维数组输出):**\n",
    "        ```\n",
    "        [7 5 3 9]\n",
    "        ```\n",
    "\n",
    "### 4. 二维数组 (2D Array)\n",
    "\n",
    "二维数组可以被看作是“数组的数组”或者一个矩阵。其结构由两个主要维度决定：\n",
    "\n",
    "*   **行数:** 代表整个二维数组中**包含多少个一维数组**。\n",
    "*   **列数:** 代表**每个一维数组（也就是每一行）中包含多少个元素**。\n",
    "\n",
    "值得注意的是，二维数组**不一定**是正方形（即行数等于列数），它可以是任意的 `n * m` 形状，其中 `n` 是行数，`m` 是列数。\n",
    "\n",
    "### 5. 数组的创建\n",
    "\n",
    "NumPy 的 `array()` 函数非常灵活，可以接受各种“序列型”对象作为输入参数来创建数组。这意味着你可以将 Python 的**列表 (List)**、**元组 (Tuple)**，甚至其他的 NumPy **数组**等数据结构直接传递给 `np.array()` 来创建新的 NumPy 数组。\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数组的简单创建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 2  4  6  8 10 12]\n",
      "[[ 2  4  6]\n",
      " [ 8 10 12]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "a = np.array([2,4,6,8,10,12]) # 创建一个一维数组\n",
    "b = np.array([[2,4,6],[8,10,12]]) # 创建一个二维数组\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7, 5, 3, 9]\n",
      "[7 5 3 9]\n"
     ]
    }
   ],
   "source": [
    "# 分清楚列表和数组的区别\n",
    "print([7, 5, 3, 9])  # 输出: [7, 5, 3, 9]（逗号分隔）\n",
    "print(np.array([7, 5, 3, 9]))  # 输出: [7 5 3 9]（空格分隔）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6,)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape # numpy中可以用shape来查看数组的形状"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0.],\n",
       "       [0., 0., 0.]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zeros = np.zeros((2, 3)) # 创建一个2行3列的全零矩阵\n",
    "zeros"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1., 1., 1.])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ones = np.ones((3,))  # 创建一个形状为(3,)的全1数组\n",
    "ones"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 顺序数组的创建\n",
    "arange = np.arange(1, 10) # 创建一个从1到10的数组\n",
    "arange"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数组的随机化创建\n",
    "1. 在后续深度学习中，我们经常需要对数据进行随机化处理，以确保模型的泛化能力。\n",
    "2. 为了测试很多函数的性能，往往需要随机化生成很多数据。\n",
    "\n",
    "\n",
    "\n",
    "- NumPy随机数生成方法对比\n",
    "\n",
    "| 方法                     | 作用范围/分布       | 记忆口诀               | 典型应用场景           | 示例                     |\n",
    "|--------------------------|-------------------|----------------------|----------------------|--------------------------|\n",
    "| `np.random.randint(a,b)` | [a,b]整数         | \"int\"结尾表示整数      | 生成随机索引/标签      | `np.random.randint(1,10)` → 7 |\n",
    "| `random.random()`        | [0,1)浮点数       | 纯\"random\"最基础      | 简单概率模拟          | `random.random()` → 0.548 |\n",
    "| `np.random.rand()`       | [0,1)均匀分布      | \"rand\"=random+uniform | 蒙特卡洛模拟          | `np.random.rand(3)` → [0.2,0.5,0.8] |\n",
    "| `np.random.randn()`      | 标准正态分布(μ=0,σ=1) | 多一个\"n\"=normal      | 数据标准化/深度学习初始化 | `np.random.randn(2,2)` → [[-0.1,1.2],[0.5,-0.3]] |\n",
    "\n",
    "- 记忆技巧：\n",
    "1. **看结尾**：\n",
    "   - \"int\" → 整数\n",
    "   - \"n\" → 正态(normal)\n",
    "   \n",
    "2. **看前缀**：\n",
    "   - 纯\"random\" → Python基础随机\n",
    "   - \"np.random\" → NumPy增强版\n",
    "\n",
    "3. **功能差异**：\n",
    "   - `rand()`和`random()`都是均匀分布，但`rand()`能直接生成数组\n",
    "   - `randn()`生成的数据会有正有负，其他方法都是非负数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.40396838, 0.67658735],\n",
       "       [0.11142565, 0.39165721]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 创建一个2*2的随机数组c，区间为[0,1)\n",
    "c = np.random.rand(2, 2)  \n",
    "c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "所有成绩: [80.  73.6 81.5 90.2 72.7 72.7 90.8 82.7 70.3 80.4]\n",
      "最高分: 90.8 (第6个学生)\n",
      "最低分: 70.3 (第8个学生)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.random.seed(42)  # 设置随机种子以确保结果可重复\n",
    "\n",
    "# 生成10个语文成绩（正态分布，均值75，标准差10）\n",
    "chinese_scores = np.random.normal(75, 10, 10).round(1)\n",
    "\n",
    "# 找出最高分和最低分及其索引\n",
    "max_score = np.max(chinese_scores)\n",
    "max_index = np.argmax(chinese_scores)\n",
    "min_score = np.min(chinese_scores)\n",
    "min_index = np.argmin(chinese_scores)\n",
    "\n",
    "print(f\"所有成绩: {chinese_scores}\")\n",
    "print(f\"最高分: {max_score} (第{max_index}个学生)\")\n",
    "print(f\"最低分: {min_score} (第{min_index}个学生)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数组的遍历"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "scores = np.array([5, 9, 9, 11, 11, 13, 15, 19])\n",
    "scores += 1 # 学习一下这个写法，等价于 scores = scores + 1\n",
    "sum = 0\n",
    "for i in scores: # 遍历数组中的每个元素\n",
    "    sum += i   \n",
    "print(sum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数组的运算\n",
    "\n",
    "1. 矩阵乘法：需要满足第一个矩阵的列数等于第二个矩阵的行数，和线代的矩阵乘法算法相同。\n",
    "2. 矩阵点乘：需要满足两个矩阵的行数和列数相同，然后两个矩阵对应位置的元素相乘。\n",
    "3. 矩阵转置：将矩阵的行和列互换。\n",
    "4. 矩阵求逆：需要满足矩阵是方阵且行列式不为0，然后使用伴随矩阵除以行列式得到逆矩阵。\n",
    "5. 矩阵求行列式：需要满足矩阵是方阵，然后使用代数余子式展开计算行列式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2]\n",
      " [3 4]\n",
      " [5 6]]\n",
      "[[ 7  8]\n",
      " [ 9 10]\n",
      " [11 12]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "a = np.array([[1, 2], [3, 4], [5, 6]])\n",
    "b = np.array([[7, 8], [9, 10], [11, 12]])\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 8 10]\n",
      " [12 14]\n",
      " [16 18]]\n"
     ]
    }
   ],
   "source": [
    "print(a + b) # 计算两个数组的和"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-6 -6]\n",
      " [-6 -6]\n",
      " [-6 -6]]\n"
     ]
    }
   ],
   "source": [
    "print(a - b) # 计算两个数组的差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.14285714 0.25      ]\n",
      " [0.33333333 0.4       ]\n",
      " [0.45454545 0.5       ]]\n"
     ]
    }
   ],
   "source": [
    "print(a / b) # 计算两个数组的除法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7, 16],\n",
       "       [27, 40],\n",
       "       [55, 72]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a * b # 矩阵点乘，ipynb文件中不使用print()函数会自动输出结果，这是ipynb文件的特性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 23,  29,  35],\n",
       "       [ 53,  67,  81],\n",
       "       [ 83, 105, 127]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a @ b.T # 矩阵乘法,3*2的矩阵和2*3的矩阵相乘，得到3*3的矩阵"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数组的索引"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一维数组索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr1d = np.arange(10)  # 数组: [0 1 2 3 4 5 6 7 8 9]\n",
    "arr1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1. 取出数组的第一个元素。\n",
    "arr1d[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出数组的最后一个元素。-1表示倒数第一个元素。\n",
    "arr1d[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3, 5, 8])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3. 取出数组中索引为 3, 5, 8 的元素。\n",
    "# 使用整数数组进行索引，可以一次性取出多个元素。语法是 arr1d[[index1, index2, ...]]。\n",
    "arr1d[[3, 5, 8]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 4, 5])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 切片取出索引\n",
    "arr1d[2:6] # 取出索引为2到5的元素（不包括索引6的元素，取左不取右）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出数组中从头到索引 5 (不包含 5) 的元素。\n",
    "# 使用切片 slice [:stop]\n",
    "arr1d[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出数组中从索引 4 到结尾的元素。\n",
    "# 使用切片 slice [start:]\n",
    "arr1d[4:]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出全部元素\n",
    "arr1d[:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 2, 4, 6, 8])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 7取出数组中所有偶数索引对应的元素 (即索引 0, 2, 4, 6, 8)。\n",
    "# 使用带步长的切片 slice [start:stop:step]\n",
    "arr1d[::2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二维数组索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1,  2,  3,  4],\n",
       "       [ 5,  6,  7,  8],\n",
       "       [ 9, 10, 11, 12],\n",
       "       [13, 14, 15, 16]])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 数组:\n",
    "arr2d = np.array([[1, 2, 3, 4],\n",
    "                  [5, 6, 7, 8],\n",
    "                  [9, 10, 11, 12],\n",
    "                  [13, 14, 15, 16]])\n",
    "arr2d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "索引顺序：在二维数组 arr2d 里，第一个索引值代表行，第二个索引值代表列。比如 arr2d[i, j] ，i 是行索引，j 是列索引。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 6, 7, 8])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出第 1 行 (索引为 1) 的所有元素。\n",
    "#\n",
    "# 使用索引 arr[row_index, :] 或 arr[row_index]\n",
    "arr2d[1, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 6, 7, 8])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 也可以省略后面的 :\n",
    "arr2d[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 3,  7, 11, 15])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出第 2 列 (索引为 2) 的所有元素。\n",
    "# 使用索引 arr[:, column_index]\n",
    "arr2d[:, 2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出位于第 2 行 (索引 2)、第 3 列 (索引 3) 的元素。\n",
    "# 使用 arr[row_index, column_index]\n",
    "arr2d[2, 3]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1,  2,  3,  4],\n",
       "       [ 9, 10, 11, 12]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出由第 0 行和第 2 行组成的新数组。\n",
    "# 使用整数数组作为行索引 arr[[row1, row2, ...], :]\n",
    "arr2d[[0, 2], :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2,  4],\n",
       "       [ 6,  8],\n",
       "       [10, 12],\n",
       "       [14, 16]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出由第 1 列和第 3 列组成的新数组。\n",
    "# 使用整数数组作为列索引 arr[:, [col1, col2, ...]]\n",
    "arr2d[:, [1, 3]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 6,  7],\n",
       "       [10, 11]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 取出一个 2x2 的子矩阵，包含元素 6, 7, 10, 11。\n",
    "# 使用切片 slice arr[row_start:row_stop, col_start:col_stop]\n",
    "arr2d[1:3, 1:3]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三维数组索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 0,  1,  2,  3,  4],\n",
       "        [ 5,  6,  7,  8,  9],\n",
       "        [10, 11, 12, 13, 14],\n",
       "        [15, 16, 17, 18, 19]],\n",
       "\n",
       "       [[20, 21, 22, 23, 24],\n",
       "        [25, 26, 27, 28, 29],\n",
       "        [30, 31, 32, 33, 34],\n",
       "        [35, 36, 37, 38, 39]],\n",
       "\n",
       "       [[40, 41, 42, 43, 44],\n",
       "        [45, 46, 47, 48, 49],\n",
       "        [50, 51, 52, 53, 54],\n",
       "        [55, 56, 57, 58, 59]]])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr3d = np.arange(3 * 4 * 5).reshape((3, 4, 5))\n",
    "arr3d "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[20, 21, 22, 23, 24],\n",
       "       [25, 26, 27, 28, 29],\n",
       "       [30, 31, 32, 33, 34],\n",
       "       [35, 36, 37, 38, 39]])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 选择特定的层\n",
    "# 使用整数数组 [0, 2] 作为第一个维度 (层) 的索引\n",
    "arr3d[1, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[20, 21, 22, 23, 24],\n",
       "       [25, 26, 27, 28, 29]])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr3d[1, 0:2, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[22, 23],\n",
       "       [27, 28]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "arr3d[1, 0:2, 2:4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SHAP值的深入理解"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在重新审视下之前的shap数组"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- 1. 默认参数随机森林 (训练集 -> 测试集) ---\n",
      "训练与预测耗时: 0.9712 秒\n",
      "\n",
      "默认随机森林 在测试集上的分类报告：\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.77      0.97      0.86      1059\n",
      "           1       0.79      0.30      0.43       441\n",
      "\n",
      "    accuracy                           0.77      1500\n",
      "   macro avg       0.78      0.63      0.64      1500\n",
      "weighted avg       0.77      0.77      0.73      1500\n",
      "\n",
      "默认随机森林 在测试集上的混淆矩阵：\n",
      "[[1023   36]\n",
      " [ 309  132]]\n"
     ]
    }
   ],
   "source": [
    "# 先运行之前预处理好的代码\n",
    "import pandas as pd\n",
    "import pandas as pd    #用于数据处理和分析，可处理表格数据。\n",
    "import numpy as np     #用于数值计算，提供了高效的数组操作。\n",
    "import matplotlib.pyplot as plt    #用于绘制各种类型的图表\n",
    "import seaborn as sns   #基于matplotlib的高级绘图库，能绘制更美观的统计图形。\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    " \n",
    " # 设置中文字体（解决中文显示问题）\n",
    "plt.rcParams['font.sans-serif'] = ['SimHei']  # Windows系统常用黑体字体\n",
    "plt.rcParams['axes.unicode_minus'] = False    # 正常显示负号\n",
    "data = pd.read_csv('data.csv')    #读取数据\n",
    "\n",
    "\n",
    "# 先筛选字符串变量 \n",
    "discrete_features = data.select_dtypes(include=['object']).columns.tolist()\n",
    "# Home Ownership 标签编码\n",
    "home_ownership_mapping = {\n",
    "    'Own Home': 1,\n",
    "    'Rent': 2,\n",
    "    'Have Mortgage': 3,\n",
    "    'Home Mortgage': 4\n",
    "}\n",
    "data['Home Ownership'] = data['Home Ownership'].map(home_ownership_mapping)\n",
    "\n",
    "# Years in current job 标签编码\n",
    "years_in_job_mapping = {\n",
    "    '< 1 year': 1,\n",
    "    '1 year': 2,\n",
    "    '2 years': 3,\n",
    "    '3 years': 4,\n",
    "    '4 years': 5,\n",
    "    '5 years': 6,\n",
    "    '6 years': 7,\n",
    "    '7 years': 8,\n",
    "    '8 years': 9,\n",
    "    '9 years': 10,\n",
    "    '10+ years': 11\n",
    "}\n",
    "data['Years in current job'] = data['Years in current job'].map(years_in_job_mapping)\n",
    "\n",
    "# Purpose 独热编码，记得需要将bool类型转换为数值\n",
    "data = pd.get_dummies(data, columns=['Purpose'])\n",
    "data2 = pd.read_csv(\"data.csv\") # 重新读取数据，用来做列名对比\n",
    "list_final = [] # 新建一个空列表，用于存放独热编码后新增的特征名\n",
    "for i in data.columns:\n",
    "    if i not in data2.columns:\n",
    "       list_final.append(i) # 这里打印出来的就是独热编码后的特征名\n",
    "for i in list_final:\n",
    "    data[i] = data[i].astype(int) # 这里的i就是独热编码后的特征名\n",
    "\n",
    "\n",
    "\n",
    "# Term 0 - 1 映射\n",
    "term_mapping = {\n",
    "    'Short Term': 0,\n",
    "    'Long Term': 1\n",
    "}\n",
    "data['Term'] = data['Term'].map(term_mapping)\n",
    "data.rename(columns={'Term': 'Long Term'}, inplace=True) # 重命名列\n",
    "continuous_features = data.select_dtypes(include=['int64', 'float64']).columns.tolist()  #把筛选出来的列名转换成列表\n",
    " \n",
    " # 连续特征用中位数补全\n",
    "for feature in continuous_features:     \n",
    "    mode_value = data[feature].mode()[0]            #获取该列的众数。\n",
    "    data[feature].fillna(mode_value, inplace=True)          #用众数填充该列的缺失值，inplace=True表示直接在原数据上修改。\n",
    "\n",
    "# 最开始也说了 很多调参函数自带交叉验证，甚至是必选的参数，你如果想要不交叉反而实现起来会麻烦很多\n",
    "# 所以这里我们还是只划分一次数据集\n",
    "from sklearn.model_selection import train_test_split\n",
    "X = data.drop(['Credit Default'], axis=1)  # 特征，axis=1表示按列删除\n",
    "y = data['Credit Default'] # 标签\n",
    "# 按照8:2划分训练集和测试集\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # 80%训练集，20%测试集\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier #随机森林分类器\n",
    "\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score # 用于评估分类器性能的指标\n",
    "from sklearn.metrics import classification_report, confusion_matrix #用于生成分类报告和混淆矩阵\n",
    "import warnings #用于忽略警告信息\n",
    "warnings.filterwarnings(\"ignore\") # 忽略所有警告信息\n",
    "# --- 1. 默认参数的随机森林 ---\n",
    "# 评估基准模型，这里确实不需要验证集\n",
    "print(\"--- 1. 默认参数随机森林 (训练集 -> 测试集) ---\")\n",
    "import time # 这里介绍一个新的库，time库，主要用于时间相关的操作，因为调参需要很长时间，记录下会帮助后人知道大概的时长\n",
    "start_time = time.time() # 记录开始时间\n",
    "rf_model = RandomForestClassifier(random_state=42)\n",
    "rf_model.fit(X_train, y_train) # 在训练集上训练\n",
    "rf_pred = rf_model.predict(X_test) # 在测试集上预测\n",
    "end_time = time.time() # 记录结束时间\n",
    "\n",
    "print(f\"训练与预测耗时: {end_time - start_time:.4f} 秒\")\n",
    "print(\"\\n默认随机森林 在测试集上的分类报告：\")\n",
    "print(classification_report(y_test, rf_pred))\n",
    "print(\"默认随机森林 在测试集上的混淆矩阵：\")\n",
    "print(confusion_matrix(y_test, rf_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shap\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 初始化 SHAP 解释器\n",
    "explainer = shap.TreeExplainer(rf_model)\n",
    "\n",
    "# 计算 SHAP 值（基于测试集），这个shap_values是一个numpy数组，表示每个特征对每个样本的贡献值\n",
    "# 这里大家先知道这是个numpy数组即可，我们后面学习完numpy在来回头解读这个值\n",
    "shap_values = explainer.shap_values(X_test) # 这个计算耗时"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 9.07465700e-03, -9.07465700e-03],\n",
       "        [ 7.21456498e-03, -7.21456498e-03],\n",
       "        [ 4.55189444e-02, -4.55189444e-02],\n",
       "        ...,\n",
       "        [ 7.12857198e-05, -7.12857198e-05],\n",
       "        [ 4.67733508e-05, -4.67733508e-05],\n",
       "        [ 1.61298135e-04, -1.61298135e-04]],\n",
       "\n",
       "       [[-1.02606871e-02,  1.02606871e-02],\n",
       "        [ 1.85572634e-02, -1.85572634e-02],\n",
       "        [-1.64992848e-02,  1.64992848e-02],\n",
       "        ...,\n",
       "        [ 2.00070852e-04, -2.00070852e-04],\n",
       "        [ 5.11798841e-05, -5.11798841e-05],\n",
       "        [ 1.02827796e-04, -1.02827796e-04]],\n",
       "\n",
       "       [[ 3.21529115e-03, -3.21529115e-03],\n",
       "        [ 1.28184070e-02, -1.28184070e-02],\n",
       "        [ 1.02124914e-01, -1.02124914e-01],\n",
       "        ...,\n",
       "        [ 1.73012306e-04, -1.73012306e-04],\n",
       "        [ 4.74133256e-05, -4.74133256e-05],\n",
       "        [ 1.26753231e-04, -1.26753231e-04]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[ 1.15222741e-03, -1.15222741e-03],\n",
       "        [-1.71843266e-02,  1.71843266e-02],\n",
       "        [-3.04994337e-02,  3.04994337e-02],\n",
       "        ...,\n",
       "        [ 1.44859329e-04, -1.44859329e-04],\n",
       "        [ 1.80111014e-05, -1.80111014e-05],\n",
       "        [ 1.30107512e-04, -1.30107512e-04]],\n",
       "\n",
       "       [[ 1.29249120e-03, -1.29249120e-03],\n",
       "        [ 5.66948438e-03, -5.66948438e-03],\n",
       "        [ 2.49050264e-02, -2.49050264e-02],\n",
       "        ...,\n",
       "        [ 2.50590715e-06, -2.50590715e-06],\n",
       "        [ 4.68839113e-05, -4.68839113e-05],\n",
       "        [ 1.15002997e-05, -1.15002997e-05]],\n",
       "\n",
       "       [[-1.12640555e-03,  1.12640555e-03],\n",
       "        [ 1.42648293e-02, -1.42648293e-02],\n",
       "        [ 4.74790019e-02, -4.74790019e-02],\n",
       "        ...,\n",
       "        [ 6.19451775e-05, -6.19451775e-05],\n",
       "        [ 3.30996384e-05, -3.30996384e-05],\n",
       "        [ 4.45219920e-05, -4.45219920e-05]]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shap_values "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 9.07465700e-03, -9.07465700e-03],\n",
       "       [ 7.21456498e-03, -7.21456498e-03],\n",
       "       [ 4.55189444e-02, -4.55189444e-02],\n",
       "       [ 3.47666501e-04, -3.47666501e-04],\n",
       "       [ 2.57821493e-04, -2.57821493e-04],\n",
       "       [ 2.00758099e-03, -2.00758099e-03],\n",
       "       [-7.54175659e-03,  7.54175659e-03],\n",
       "       [-1.35324163e-03,  1.35324163e-03],\n",
       "       [-7.08191659e-04,  7.08191659e-04],\n",
       "       [-6.06829865e-03,  6.06829865e-03],\n",
       "       [-1.90501403e-03,  1.90501403e-03],\n",
       "       [ 1.44384291e-02, -1.44384291e-02],\n",
       "       [-4.91452434e-02,  4.91452434e-02],\n",
       "       [ 6.28172371e-03, -6.28172371e-03],\n",
       "       [-1.64613559e-02,  1.64613559e-02],\n",
       "       [-6.04576031e-01,  6.04576031e-01],\n",
       "       [ 4.58074016e-04, -4.58074016e-04],\n",
       "       [-1.95125086e-05,  1.95125086e-05],\n",
       "       [-1.47478232e-05,  1.47478232e-05],\n",
       "       [ 6.27274034e-04, -6.27274034e-04],\n",
       "       [-1.26003035e-05,  1.26003035e-05],\n",
       "       [-3.58303017e-04,  3.58303017e-04],\n",
       "       [ 7.89740644e-05, -7.89740644e-05],\n",
       "       [ 2.08492876e-04, -2.08492876e-04],\n",
       "       [ 5.52330472e-06, -5.52330472e-06],\n",
       "       [ 4.11019037e-04, -4.11019037e-04],\n",
       "       [ 7.15614011e-06, -7.15614011e-06],\n",
       "       [ 1.07037925e-04, -1.07037925e-04],\n",
       "       [ 7.12857198e-05, -7.12857198e-05],\n",
       "       [ 4.67733508e-05, -4.67733508e-05],\n",
       "       [ 1.61298135e-04, -1.61298135e-04]])"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shap_values[0,:,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(31, 2)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shap_values[0,:,:].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这个对应的是（特征数，类别数目）----每个特征对2个目标类别的shap值贡献\n",
    "\n",
    "所以这个值对应的就是这个样本对应的这个特征对2个目标类别的shap值贡献"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1500, 31, 2)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 三个维度\n",
    "# 第一个维度是样本数\n",
    "# 第二个维度是特征数\n",
    "# 第三个维度是类别数\n",
    "shap_values.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 9.07465700e-03,  7.21456498e-03,  4.55189444e-02, ...,\n",
       "         7.12857198e-05,  4.67733508e-05,  1.61298135e-04],\n",
       "       [-1.02606871e-02,  1.85572634e-02, -1.64992848e-02, ...,\n",
       "         2.00070852e-04,  5.11798841e-05,  1.02827796e-04],\n",
       "       [ 3.21529115e-03,  1.28184070e-02,  1.02124914e-01, ...,\n",
       "         1.73012306e-04,  4.74133256e-05,  1.26753231e-04],\n",
       "       ...,\n",
       "       [ 1.15222741e-03, -1.71843266e-02, -3.04994337e-02, ...,\n",
       "         1.44859329e-04,  1.80111014e-05,  1.30107512e-04],\n",
       "       [ 1.29249120e-03,  5.66948438e-03,  2.49050264e-02, ...,\n",
       "         2.50590715e-06,  4.68839113e-05,  1.15002997e-05],\n",
       "       [-1.12640555e-03,  1.42648293e-02,  4.74790019e-02, ...,\n",
       "         6.19451775e-05,  3.30996384e-05,  4.45219920e-05]])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 比如我想取出所有样本对第一个类别的贡献值\n",
    "shap_values[:,:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # --- 1. SHAP 特征重要性条形图 (Summary Plot - Bar) ---\n",
    "# print(\"--- 1. SHAP 特征重要性条形图 ---\")\n",
    "# shap.summary_plot(shap_values[:, :, 0], X_test, plot_type=\"bar\",show=False)  #  这里的show=False表示不直接显示图形,这样可以继续用plt来修改元素，不然就直接输出了\n",
    "# plt.title(\"SHAP Feature Importance (Bar Plot)\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时可以理解为什么shap.summary_plot中第一个参数是所有样本对预测类别的shap值了。\n",
    "\n",
    "传入的 SHAP 值 (shap_values[:, :, 0]) 和特征数据 (X_test) 在维度上需要高度一致和对应。\n",
    "\n",
    "- shap_values[:, :, 0] 的每一行代表的是 一个特定样本每个特征对于预测类别的贡献值（SHAP 值）。缺乏特征本身的值\n",
    "- X_test 的每一行代表的也是同一个特定样本的特征值。\n",
    "\n",
    "这二者组合后，就可以组合（特征数，特征值，shap值）构成shap图的基本元素\n",
    "\n",
    "上面这些就是我对于shap图的理解，去年很多同学在这里经常和我说是借助ai没办法对这里debug，实际上是因为没有理解shap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
